{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# REINFORCE in Sonnet\n",
    "\n",
    "This notebook implements a basic reinforce algorithm a.k.a. policy gradient for CartPole env.\n",
    "\n",
    "It has been deliberately written to be as simple and human-readable.\n",
    "\n",
    "Authors: [Practical_RL](https://github.com/yandexdataschool/Practical_RL) course team"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The notebook assumes that you have [openai gym](https://github.com/openai/gym) installed.\n",
    "\n",
    "In case you're running on a server, [use xvfb](https://github.com/openai/gym#rendering-on-a-server)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2017-04-08 03:25:37,640] Making new env: CartPole-v0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7ff60d8ed550>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEACAYAAACwB81wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAE9RJREFUeJzt3X+s3fV93/Hni18jCSWWRecY49SogIJbMpMNMy1DnDbE\nNdMGiSYFWqViW5Sh0SYo07aYaA13mcScSGH5Y0rUNqR1u8SNlTYIqkIxjCORKYH8wMHEuOANr5hh\nk+UHgaKodnjvj/u1OVwb33PvPede+3yeD+nIn+/n+/me7+cjXb/O53x/nG+qCknS5DtlqTsgSVoc\nBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiPGEvhJNibZneSpJB8dxz4kSXOTUV+Hn+RU4K+Aq4BngW8C\nv15VT4x0R5KkORnHDH89sKeq9lbVQeBPgGvHsB9J0hyMI/BXAc8MLO/r6iRJS2gcge9vNUjSCei0\nMbzns8DqgeXVTM/yj0jih4IkzUNVZb7bjmOG/y3gwiRrkpwBXAfcNbNRVU3s69Zbb13yPjg+x9fi\n+CZ5bFULnyePfIZfVYeS/Dbwl8CpwB3lFTqStOTGcUiHqroHuGcc7y1Jmh/vtB2DXq+31F0YK8d3\ncpvk8U3y2EZh5DdeDbXTpJZiv5J0MktCnWAnbSVJJyADX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+\nJDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUiAU90zbJXuAn\nwM+Ag1W1Psly4MvALwB7gfdV1Y8X2E9J0gItdIZfQK+qLq2q9V3dJmB7VV0EPNAtS5KW2CgO6cx8\nvuI1wJauvAV4zwj2IUlaoFHM8O9P8q0kH+zqVlTVga58AFixwH1IkkZgQcfwgXdW1XNJfh7YnmT3\n4MqqqiS1wH1IkkZgQYFfVc91/34/yVeB9cCBJG+pqv1JVgLPH2vbqampI+Ver0ev11tIVyRp4vT7\nffr9/sjeL1Xzm4AneSNwalW9mORNwH3AfwKuAn5QVZ9MsglYVlWbZmxb892vJLUqCVU187zp8Nsv\nIPDPB77aLZ4GfLGq/kt3WeY24K28zmWZBr4kzd2SBf5CGPiSNHcLDXzvtJWkRhj4ktQIA1+SGmHg\nS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4k\nNcLAl6RGGPiS1AgDX5IaMWvgJ/lCkgNJdg7ULU+yPcmTSe5Lsmxg3S1JnkqyO8mGcXVckjQ3w8zw\n/wDYOKNuE7C9qi4CHuiWSbIWuA5Y223z2SR+i5CkE8CsYVxVDwE/mlF9DbClK28B3tOVrwW2VtXB\nqtoL7AHWj6arkqSFmO/se0VVHejKB4AVXflcYN9Au33AqnnuQ5I0Qgs+3FJVBdTxmix0H5KkhTtt\nntsdSPKWqtqfZCXwfFf/LLB6oN15Xd1RpqamjpR7vR69Xm+eXZGkydTv9+n3+yN7v0xP0GdplKwB\n7q6qS7rlTwE/qKpPJtkELKuqTd1J2y8xfdx+FXA/cEHN2EmSmVWSpFkkoaoy3+1nneEn2QpcCZyT\n5Bng48BmYFuSDwB7gfcBVNWuJNuAXcAh4CaTXZJODEPN8Ee+U2f4kjRnC53he428JDXCwJekRhj4\nktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9J\njTDwJakRBr4kNcLAl6RGGPiS1IhZAz/JF5IcSLJzoG4qyb4kj3avqwfW3ZLkqSS7k2wYV8clSXMz\n60PMk1wBvAT8UVVd0tXdCrxYVbfPaLsW+BJwGbAKuB+4qKpemdHOh5hL0hyN/SHmVfUQ8KNj7fsY\nddcCW6vqYFXtBfYA6+fbOUnS6CzkGP6Hknw3yR1JlnV15wL7BtrsY3qmL0laYvMN/M8B5wPrgOeA\nTx+nrcduJOkEcNp8Nqqq5w+Xk3weuLtbfBZYPdD0vK7uKFNTU0fKvV6PXq83n65I0sTq9/v0+/2R\nvd+sJ20BkqwB7h44abuyqp7ryh8BLquq3xg4abueV0/aXjDzDK0nbSVp7hZ60nbWGX6SrcCVwDlJ\nngFuBXpJ1jF9uOZp4EaAqtqVZBuwCzgE3GSyS9KJYagZ/sh36gxfkuZs7JdlSpImg4EvSY0w8CWp\nEQa+JDXCwJekRhj4ktQIA1+SGjGvn1aQJs0rPzvI7js3A7D2n//OEvdGGg8DX836X/d9jh/v3bHU\n3ZAWjYd0JKkRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDVi1sBPsjrJ\ng0m+l+TxJB/u6pcn2Z7kyST3JVk2sM0tSZ5KsjvJhnEOQJI0nGFm+AeBj1TVLwH/EPitJBcDm4Dt\nVXUR8EC3TJK1wHXAWmAj8NkkfpOQpCU2axBX1f6q2tGVXwKeAFYB1wBbumZbgPd05WuBrVV1sKr2\nAnuA9SPutyRpjuY0806yBrgUeBhYUVUHulUHgBVd+Vxg38Bm+5j+gJAkLaGhfx45yVnAnwI3V9WL\nSY6sq6pKUsfZ/Kh1U1NTR8q9Xo9erzdsVySpCf1+n36/P7L3S9XxcrprlJwO/DlwT1V9pqvbDfSq\nan+SlcCDVfW2JJsAqmpz1+5e4Naqenjg/WqY/Urj9Hq/h//3//XvLkFvpNkloaoye8tjG+YqnQB3\nALsOh33nLuCGrnwDcOdA/fVJzkhyPnAh8Mh8OyhJGo1hjuG/E3g/8CtJHu1eG4HNwLuTPAn8ardM\nVe0CtgG7gHuAm5zO60T0ixv+zTHr/89D/32ReyItjlmP4VfV13j9D4arXmeb24DbFtAvSdKIeX28\nJDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtS\nIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1IhZAz/J6iQPJvlekseTfLirn0qyb+DB5lcPbHNL\nkqeS7E6yYZwDkCQNZ9aHmAMHgY9U1Y4kZwHfTrIdKOD2qrp9sHGStcB1wFpgFXB/kouq6pUR912S\nNAezzvCran9V7ejKLwFPMB3kADnGJtcCW6vqYFXtBfYA60fTXUnSfM3pGH6SNcClwDe6qg8l+W6S\nO5Is6+rOBfYNbLaPVz8gJElLZJhDOgB0h3O+AtxcVS8l+RzwiW71fwY+DXzgdTavmRVTU1NHyr1e\nj16vN2xXJKkJ/X6ffr8/svdL1VFZfHSj5HTgz4F7quozx1i/Bri7qi5JsgmgqjZ36+4Fbq2qhwfa\n1zD7lcbt279341F151x8Bb9wxfuXoDfS8SWhqo51KH0ow1ylE+AOYNdg2CdZOdDsvcDOrnwXcH2S\nM5KcD1wIPDLfDkqSRmOYQzrvBN4PPJbk0a7uY8CvJ1nH9OGap4EbAapqV5JtwC7gEHCT03lJWnqz\nBn5VfY1jfxO45zjb3AbctoB+SZJGzDttJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEv\nSY0w8CWpEQa+JDXCwFfTVq1/71F1/++Jh5agJ9L4GfiS1AgDX5IaYeBLUiOGeuLVyHfqE680Rg88\n8AA/+clPhmp79sv/m2V/8+RR9X/98xuH3t+73vUuzj777KHbS/O10CdeGfiaOG9/+9vZuXPn7A2B\nf7FxHb/93vVH1f+DG39v6P099thjXHLJJUO3l+ZroYE/9EPMpUn1o79dwbd/fNWR5av+7heXsDfS\n+HgMX03b/9M1/M8fXMtPf/amI6/HXrhiqbsljYWBr6a9dGjZUXV//fLFS9ATafyOG/hJzkzycJId\nSR5PMtXVL0+yPcmTSe5Lsmxgm1uSPJVkd5INY+6/NHJ/55SXl7oL0lgcN/Cr6qfAr1TVOmAdsDHJ\n5cAmYHtVXQQ80C2TZC1wHbAW2Ah8NonfInTCuuCsHZySV15Tt/LMp5eoN9J4zXrStqoOT3fOAE4H\nCrgGuLKr3wL0mQ79a4GtVXUQ2JtkD7Ae+MZouy2NzoYVW3j50Nl89He388z3X+DnTvvhUndJGotZ\nA7+boX8H+EXgv1XVI0lWVNWBrskBYEVXPpfXhvs+YNWx3vfGG2+cd6el43n22WeHbvuH9+7gD+/d\n8Zq6/XPc3yc+8QmWL18+x62kxTfMDP8VYF2SNwNfTfLLM9ZXkuNdVH/MdStXrjxS7vV69Hq9oTos\nzebrX/86P/zh4s3SP/7xj3sdvsai3+/T7/dH9n5zuvEqye8ALwMfBHpVtT/JSuDBqnpbkk0AVbW5\na38vcGtVPTzjfbzxSmMzlxuvRsEbr7RYFnrj1WxX6Zxz+AqcJG8A3g08AdwF3NA1uwG4syvfBVyf\n5Iwk5wMXAo/Mt3OSpNGZ7ZDOSmBLklOZ/nD4clX9RZJvANuSfADYC7wPoKp2JdkG7AIOATc5lZek\nE8NxA7+qdgLvOEb9D4Grjt4Cquo24LaR9E6SNDJeIy9JjTDwJakRBr4kNcKfR9bEuf3223nhhRcW\nbX9vfetbF21f0kL4ABRJOkmM9Tp8SdLkMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqE\ngS9JjTDwJakRBr4kNcLAl6RGGPiS1IjZHmJ+ZpKHk+xI8niSqa5+Ksm+JI92r6sHtrklyVNJdifZ\nMOb+S5KGNOvPIyd5Y1W9nOQ04GvAzcBG4MWqun1G27XAl4DLgFXA/cBFVfXKjHb+PLIkzdHYfx65\nql7uimcApwOHk/pYO70W2FpVB6tqL7AHWD/fzkmSRmfWwE9ySpIdwAHgvqp6pFv1oSTfTXJHkmVd\n3bnAvoHN9zE905ckLbFhZvivVNU64Dzg8iS/BHwOOB9YBzwHfPp4bzGKjkqSFmboZ9pW1QtJHgQ2\nVtWRgE/yeeDubvFZYPXAZud1dUeZmpo6Uu71evR6vaE7LUkt6Pf79Pv9kb3fcU/aJjkHOFRVP07y\nBuAvgc3Ad6pqf9fmI8BlVfUbAydt1/PqSdsLZp6h9aStJM3dQk/azjbDXwlsSXIq04d/vlxVf5Hk\nj5KsY/pwzdPAjQBVtSvJNmAXcAi4yWSXpBPDrJdljmWnzvAlac7GflmmJGkyGPiS1AgDX5IaYeBL\nUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1\nwsCXpEYY+JLUCANfkhoxVOAnOTXJo0nu7paXJ9me5Mkk9yVZNtD2liRPJdmdZMO4Oi5JmpthZ/g3\nA7uAw08e3wRsr6qLgAe6ZZKsBa4D1gIbgc8mae5bRL/fX+oujJXjO7lN8vgmeWyjMGsYJzkP+CfA\n54HDT0u/BtjSlbcA7+nK1wJbq+pgVe0F9gDrR9nhk8Gk/9E5vpPbJI9vksc2CsPMvv8r8O+BVwbq\nVlTVga58AFjRlc8F9g202wesWmgnJUkLd9zAT/JPgeer6lFend2/RlUVrx7qOWaT+XdPkjQqmc7r\n11mZ3Ab8JnAIOBM4G/gz4DKgV1X7k6wEHqyqtyXZBFBVm7vt7wVuraqHZ7yvHwKSNA9VdczJ9zCO\nG/ivaZhcCfy7qvpnST4F/KCqPtmF/LKq2tSdtP0S08ftVwH3AxfUsDuRJI3NaXNsfzi4NwPbknwA\n2Au8D6CqdiXZxvQVPYeAmwx7SToxDD3DlySd3Bb9GvkkG7ubsp5K8tHF3v8oJPlCkgNJdg7UTcTN\naElWJ3kwyfeSPJ7kw139pIzvzCQPJ9nRjW+qq5+I8R02yTdLJtmb5LFufI90dRMxviTLknwlyRNJ\ndiW5fKRjq6pFewGnMn1t/hrgdGAHcPFi9mFE47gCuBTYOVD3KeA/dOWPApu78tpunKd3494DnLLU\nYzjO2N4CrOvKZwF/BVw8KePr+vzG7t/TgG8Al0/S+Lp+/1vgi8Bdk/T32fX5aWD5jLqJGB/T9zX9\nq4G/zzePcmyLPcNfD+ypqr1VdRD4E6Zv1jqpVNVDwI9mVE/EzWhVtb+qdnTll4AnmD4BPxHjA6iq\nl7viGUz/ZykmaHyN3Cw580qVk358Sd4MXFFVXwCoqkNV9QIjHNtiB/4q4JmB5Um6MWvibkZLsobp\nbzIPM0HjS3JKkh1Mj+O+qnqECRofk3+zZAH3J/lWkg92dZMwvvOB7yf5gyTfSfL7Sd7ECMe22IHf\nxBnimv6+dVLfjJbkLOBPgZur6sXBdSf7+KrqlapaB5wHXJ7kl2esP2nH18jNku+sqkuBq4HfSnLF\n4MqTeHynAe8APltV7wD+hu53yg5b6NgWO/CfBVYPLK/mtZ9QJ7MDSd4C0N2M9nxXP3PM53V1J6wk\npzMd9n9cVXd21RMzvsO6r8sPAr/G5IzvHwHXJHka2Ar8apI/ZnLGR1U91/37feCrTB/GmITx7QP2\nVdU3u+WvMP0BsH9UY1vswP8WcGGSNUnOYPqXNe9a5D6My13ADV35BuDOgfrrk5yR5HzgQuCRJejf\nUJIEuAPYVVWfGVg1KeM75/BVDkneALyb6fMUEzG+qvpYVa2uqvOB64H/UVW/yYSML8kbk/xcV34T\nsAHYyQSMr6r2A88kuairugr4HnA3oxrbEpyFvprpKz/2ALcs9VnxeY5hK/B/gb9l+pzEvwSWM31n\n8ZPAfUzffXy4/ce68e4Gfm2p+z/L2P4x08d+dwCPdq+NEzS+S4DvAN9lOij+Y1c/EeObMdYrefUq\nnYkYH9PHuXd0r8cPZ8gEje/vAd/s/j7/jOmrdEY2Nm+8kqRGNPdwEklqlYEvSY0w8CWpEQa+JDXC\nwJekRhj4ktQIA1+SGmHgS1Ij/j+Vb4c3YZxxLQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7ff6146dd450>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "import numpy as np, pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "env = gym.make(\"CartPole-v0\")\n",
    "\n",
    "#gym compatibility: unwrap TimeLimit\n",
    "if hasattr(env,'env'):\n",
    "    env=env.env\n",
    "\n",
    "env.reset()\n",
    "n_actions = env.action_space.n\n",
    "state_dim = env.observation_space.shape\n",
    "\n",
    "plt.imshow(env.render(\"rgb_array\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the network for REINFORCE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For REINFORCE algorithm, we'll need a model that predicts action probabilities given states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import sonnet as snt\n",
    "\n",
    "\n",
    "#create input variables. We only need <s,a,R> for REINFORCE\n",
    "states = tf.placeholder('float32',(None,)+state_dim,name=\"states\")\n",
    "actions = tf.placeholder('int32',name=\"action_ids\")\n",
    "cumulative_rewards = tf.placeholder('float32', name=\"cumulative_returns\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "\n",
    "def make_network(inputs):\n",
    "    \n",
    "    lin1 = snt.Linear(output_size=100)(inputs)\n",
    "    elu1 = tf.nn.elu(lin1)\n",
    "    \n",
    "    logits = snt.Linear(output_size=n_actions)(elu1)\n",
    "    policy = tf.nn.softmax(logits)\n",
    "    log_policy = tf.nn.log_softmax(logits)\n",
    "    \n",
    "    return logits, policy, log_policy\n",
    "\n",
    "net = snt.Module(make_network,name=\"policy_network\")\n",
    "\n",
    "logits,policy,log_policy = net(states)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#utility function to pick action in one given state\n",
    "get_action_proba = lambda s: policy.eval({states:[s]})[0] "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loss function and updates\n",
    "\n",
    "We now need to define objective and update over policy gradient.\n",
    "\n",
    "The objective function can be defined thusly:\n",
    "\n",
    "$$ J \\approx \\sum  _i log \\pi_\\theta (a_i | s_i) \\cdot R(s_i,a_i) $$\n",
    "\n",
    "When you compute gradient of that function over network weights $ \\theta $, it will become exactly the policy gradient.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#REINFORCE objective function\n",
    "actions_1hot = tf.one_hot(actions,n_actions)\n",
    "\n",
    "log_pi_a = -tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=actions_1hot)\n",
    "\n",
    "J = tf.reduce_mean(log_pi_a * cumulative_rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#regularize with entropy\n",
    "entropy = -tf.reduce_mean(policy*log_policy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#all network weights\n",
    "all_weights = net.get_variables()\n",
    "\n",
    "#weight updates. maximizing J is same as minimizing -J\n",
    "loss = -J -0.1*entropy\n",
    "update = tf.train.AdamOptimizer().minimize(loss,var_list=all_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing cumulative rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_cumulative_rewards(rewards, #rewards at each step\n",
    "                           gamma = 0.99 #discount for reward\n",
    "                           ):\n",
    "    \"\"\"\n",
    "    take a list of immediate rewards r(s,a) for the whole session \n",
    "    compute cumulative rewards R(s,a) (a.k.a. G(s,a) in Sutton '16)\n",
    "    R_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...\n",
    "    \n",
    "    The simple way to compute cumulative rewards is to iterate from last to first time tick\n",
    "    and compute R_t = r_t + gamma*R_{t+1} recurrently\n",
    "    \n",
    "    You must return an array/list of cumulative rewards with as many elements as in the initial rewards.\n",
    "    \"\"\"\n",
    "    \n",
    "    cumulative_rewards = []\n",
    "    R = 0\n",
    "    \n",
    "    for r in rewards[::-1]:\n",
    "        R = r + gamma*R\n",
    "        cumulative_rewards.insert(0,R)\n",
    "        \n",
    "    return cumulative_rewards\n",
    "    \n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "looks good!\n"
     ]
    }
   ],
   "source": [
    "assert len(get_cumulative_rewards(range(100))) == 100\n",
    "assert np.allclose(get_cumulative_rewards([0,0,1,0,0,1,0],gamma=0.9),[1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0, 0.0])\n",
    "assert np.allclose(get_cumulative_rewards([0,0,1,-2,3,-4,0],gamma=0.5), [0.0625, 0.125, 0.25, -1.5, 1.0, -4.0, 0.0])\n",
    "assert np.allclose(get_cumulative_rewards([0,0,1,2,3,4,0],gamma=0), [0, 0, 1, 2, 3, 4, 0])\n",
    "print(\"looks good!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train_step(_states,_actions,_rewards):\n",
    "    \"\"\"given full session, trains agent with policy gradient\"\"\"\n",
    "    _cumulative_rewards = get_cumulative_rewards(_rewards)\n",
    "    update.run({states:_states,actions:_actions,cumulative_rewards:_cumulative_rewards})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Playing the game"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def generate_session(t_max=1000):\n",
    "    \"\"\"play env with REINFORCE agent and train at the session end\"\"\"\n",
    "    \n",
    "    #arrays to record session\n",
    "    states,actions,rewards = [],[],[]\n",
    "    \n",
    "    s = env.reset()\n",
    "    \n",
    "    for t in range(t_max):\n",
    "        \n",
    "        #action probabilities array aka pi(a|s)\n",
    "        action_probas = get_action_proba(s)\n",
    "        \n",
    "        a = np.random.choice(n_actions,p=action_probas)\n",
    "        \n",
    "        new_s,r,done,info = env.step(a)\n",
    "        \n",
    "        #record session history to train later\n",
    "        states.append(s)\n",
    "        actions.append(a)\n",
    "        rewards.append(r)\n",
    "        \n",
    "        s = new_s\n",
    "        if done: break\n",
    "            \n",
    "    train_step(states,actions,rewards)\n",
    "            \n",
    "    return sum(rewards)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean reward:27.590\n",
      "mean reward:70.340\n",
      "mean reward:129.570\n",
      "mean reward:188.330\n",
      "mean reward:211.530\n",
      "mean reward:240.490\n",
      "mean reward:235.760\n",
      "mean reward:218.030\n",
      "mean reward:258.470\n",
      "mean reward:184.760\n",
      "mean reward:298.920\n",
      "mean reward:507.360\n",
      "You Win!\n"
     ]
    }
   ],
   "source": [
    "s = tf.InteractiveSession()\n",
    "s.run(tf.global_variables_initializer())\n",
    "\n",
    "for i in range(100):\n",
    "    \n",
    "    rewards = [generate_session() for _ in range(100)] #generate new sessions\n",
    "    \n",
    "    print (\"mean reward:%.3f\"%(np.mean(rewards)))\n",
    "\n",
    "    if np.mean(rewards) > 300:\n",
    "        print (\"You Win!\")\n",
    "        break\n",
    "        \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Results & video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2017-04-08 03:29:10,315] Making new env: CartPole-v0\n",
      "[2017-04-08 03:29:10,324] DEPRECATION WARNING: env.spec.timestep_limit has been deprecated. Replace your call to `env.spec.timestep_limit` with `env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')`. This change was made 12/28/2016 and is included in version 0.7.0\n",
      "[2017-04-08 03:29:10,329] Clearing 6 monitor files from previous run (because force=True was provided)\n",
      "[2017-04-08 03:29:10,336] Starting new video recorder writing to /home/jheuristic/Downloads/sonnet/sonnet/examples/videos/openaigym.video.0.14221.video000000.mp4\n",
      "[2017-04-08 03:29:16,834] Starting new video recorder writing to /home/jheuristic/Downloads/sonnet/sonnet/examples/videos/openaigym.video.0.14221.video000001.mp4\n",
      "[2017-04-08 03:29:23,689] Starting new video recorder writing to /home/jheuristic/Downloads/sonnet/sonnet/examples/videos/openaigym.video.0.14221.video000008.mp4\n"
     ]
    }
   ],
   "source": [
    "#record sessions\n",
    "import gym.wrappers\n",
    "env = gym.wrappers.Monitor(gym.make(\"CartPole-v0\"),directory=\"videos\",force=True)\n",
    "sessions = [generate_session() for _ in range(100)]\n",
    "env.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#show video\n",
    "from IPython.display import HTML\n",
    "import os\n",
    "\n",
    "video_names = list(filter(lambda s:s.endswith(\".mp4\"),os.listdir(\"./videos/\")))\n",
    "\n",
    "HTML(\"\"\"\n",
    "<video width=\"640\" height=\"480\" controls>\n",
    "  <source src=\"{}\" type=\"video/mp4\">\n",
    "</video>\n",
    "\"\"\".format(\"./videos/\"+video_names[-1])) #this may or may not be _last_ video. Try other indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#That's all, thank you for your attention!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [Root]",
   "language": "python",
   "name": "Python [Root]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
